/*____________________________________________________________________________
	Copyright (C) 2000 Networks Associates Technology, Inc.
	All rights reserved.
	
	$Id: pgpPassCach.c,v 1.22 2001/01/25 22:12:01 jeffc Exp $
____________________________________________________________________________*/

/*
 * A Passphrase Cache -- keep a cache of passphrases and
 * try them against ESKs and Secret Keys.  This way you only have to
 * type your passphrase once: it is cached and then reused
 * until its cache entry expires.
 */

#include <string.h>
#include "pgpConfig.h"
#include <stdio.h>

#include "pgpMem.h"
#include "pgpPassCach.h"
#include "pgpErrors.h"
#include "pgpPubKey.h"
#include "pgpEnv.h"
#include "pgpContext.h"
#include "pgpKeyPriv.h"
#include "pgpThreads.h"

/* Expiration time for entries that never expire; substituted whenever
 * now + timeoutInterval wraps around past the maximum PGPTime. */
#define FOREVER		MAX_PGPUInt32

/*
 * Per-context root of the passphrase cache, attached to the context with
 * pgpContextSetPassphraseCache().  Entries live on two singly linked
 * lists; the functions in this file access them under gPassCacheMutex.
 */
struct PGPCacheHeader {
	PGPContextRef	context;		/* context this header belongs to */
	struct PGPPassCache	*globalCache;	/* entries cached with globalCache TRUE */
	struct PGPPassCache	*localCaches;	/* entries matched by client connection */
};

/*
 * One cached passphrase.  Callers in this file store the converted
 * (hashed) form of the passphrase, never the raw text.  Each successful
 * lookup moves timeoutTime forward by timeoutInterval.
 */
struct PGPPassCache {
	struct PGPPassCache *next;	/* next entry on the same list */
	PGPConnectRef	clientID;	/* storing connection; null connect ref for
								 * global entries on non-Win32 builds */
	PGPBoolean		globalCache;	/* TRUE if linked on the global list */
	PGPKeyID		keyID;			/* key this passphrase belongs to */
	PGPUInt32		timeoutInterval; /* timeout interval in secs */
	PGPTime			timeoutTime;      /* absolute time this entry expires */
	PGPByte *		hashedPhrase;	/* hashed passphrase; wiped before free */
	PGPSize			hashedPhraseLength;	/* bytes in hashedPhrase */
	PGPByte *		userName;		/* only used on global cache; NUL-terminated
									 * copy, may be NULL */
};

/* Serializes access to the passphrase cache lists in this file.
 * NOTE(review): not initialized in this file; presumably created elsewhere
 * before first use -- confirm. */
PGPMutex_t gPassCacheMutex;

/*
 * Lazily create the passphrase cache for a context.  A zeroed header
 * (no global or local entries yet) is allocated, tagged with its owning
 * context, and attached with pgpContextSetPassphraseCache().
 *
 * Returns kPGPError_OutOfMemory if the header cannot be allocated,
 * otherwise kPGPError_NoErr.
 */
	static PGPError
sInitPassphraseCache( PGPContextRef context )
{
	struct PGPCacheHeader *	header;

	header = pgpContextMemAlloc( context, sizeof(*header), 0 );
	if( IsntNull( header ) )
	{
		pgpClearMemory( header, sizeof(*header) );
		header->context = context;
		pgpContextSetPassphraseCache( context, header );
		return kPGPError_NoErr;
	}
	return kPGPError_OutOfMemory;
}


/* Destroy data in a given passphrase cache entry */
	static void
sDestroyPassCache( struct PGPPassCache *cache )
{
	if( IsntNull( cache->hashedPhrase ) )
	{
		pgpClearMemory( cache->hashedPhrase, cache->hashedPhraseLength );
		PGPFreeData( cache->hashedPhrase );
	}
	if( IsntNull( cache->userName ) )
	{
		PGPFreeData( cache->userName );
	}
	pgpClearMemory( cache, sizeof(*cache) );
	PGPFreeData( cache );
}

/*
 * Unlink the entry that *cachePtr refers to and destroy it.  On return
 * *cachePtr points at the entry that followed it, or NULL.
 */
	static void
sExpirePassCache( struct PGPPassCache **cachePtr )
{
	struct PGPPassCache *	victim = *cachePtr;

	/* Splice the entry out first; its next field is read before it is freed. */
	*cachePtr = victim->next;
	sDestroyPassCache( victim );
}
	
	/*
	 * Drop the whole passphrase cache for a context.  Every global and
	 * local entry is destroyed, the header is wiped and freed, and the
	 * context's cache pointer is reset to NULL.  Runs under
	 * gPassCacheMutex and always returns kPGPError_NoErr.
	 */
	PGPError
pgpPurgePassphraseCache_internal( PGPContextRef context )
{
	struct PGPCacheHeader *	header;
	struct PGPPassCache *	lists[2];
	struct PGPPassCache *	entry;
	struct PGPPassCache *	following;
	int						i;

	PGPMutexLock(&gPassCacheMutex);

	header = pgpContextGetPassphraseCache( context );
	if( IsNull( header ) )
	{
		PGPMutexUnlock(&gPassCacheMutex);
		return kPGPError_NoErr;
	}

	/* Detach both lists, then walk and destroy each one. */
	lists[0] = header->globalCache;
	lists[1] = header->localCaches;
	header->globalCache = NULL;
	header->localCaches = NULL;

	for( i = 0; i < 2; ++i )
	{
		for( entry = lists[i]; IsntNull( entry ); entry = following )
		{
			following = entry->next;
			sDestroyPassCache( entry );
		}
	}

	pgpClearMemory( header, sizeof(*header) );
	PGPFreeData( header );
	pgpContextSetPassphraseCache( context, NULL);

	PGPMutexUnlock(&gPassCacheMutex);
	return kPGPError_NoErr;
}




/*
 * Put the given (already hashed) passphrase into the cache for keyID.
 *
 * An entry matches on client connection, user name (global entries on
 * Win32, when the connection carries one) and keyID.  If a matching entry
 * exists, its passphrase is replaced.  Otherwise a new entry is pushed onto
 * the front of the global or local list.  In both cases the expiration is
 * reset to now + cacheTimeOut seconds, saturating at FOREVER if the
 * addition wraps.
 *
 * Callers in this file hold gPassCacheMutex around this call.
 *
 * Returns kPGPError_NoErr on success, or an allocation error.  On failure
 * no half-built entry is linked in, and an existing entry keeps its
 * previous passphrase.
 */
	static PGPError
sCachePassphraseInternal(PGPContextRef context, PGPKeyID keyID,
	PGPByte const *passPhrase, PGPSize passLength,
	PGPUInt32 cacheTimeOut, PGPBoolean globalCache)
{
	struct PGPCacheHeader *	cacheHeader;
	struct PGPPassCache *	cache;
	PGPConnectRef			clientID = kPGPConnectRef_Null;
	PGPTime					now;
	PGPByte *				userName = NULL;
	PGPSize					userNameLength = 0;
	PGPByte *				newPhrase;
	PGPError				err;

	cacheHeader = pgpContextGetPassphraseCache( context );
	if( IsNull( cacheHeader ) )
	{
		/* Report the failure, as the entry allocation below does. */
		err = sInitPassphraseCache( context );
		if( IsPGPError( err ) )
			return err;
		cacheHeader = pgpContextGetPassphraseCache( context );
	}

	if( globalCache )
	{
		cache = cacheHeader->globalCache;
#if PGP_WIN32
		{
			pgpRPCconnection *connectRef;
			clientID = pgpContextGetConnectRef( context );
			connectRef = (pgpRPCconnection *) clientID;
			if( IsntNull( connectRef ) && IsntNull( connectRef->UserName ) )
			{
				userName = connectRef->UserName;
				userNameLength = strlen( userName );
			}
		}
#endif
	} else {
		clientID = pgpContextGetConnectRef( context );
		cache = cacheHeader->localCaches;
	}

	/*
	 * Copy the new passphrase before touching any entry.  If this fails,
	 * the cache is left unchanged instead of holding an entry whose
	 * passphrase buffer is NULL.
	 */
	newPhrase = (PGPByte *)PGPNewSecureData(
						PGPPeekContextMemoryMgr( context ), passLength, 0 );
	if( IsNull( newPhrase ) )
		return kPGPError_OutOfMemory;
	pgpCopyMemory( passPhrase, newPhrase, passLength );

	while( IsntNull(cache) )
	{
		pgpAssert( !cache->globalCache == !globalCache );
		/* cache->userName may be NULL, so guard it before strcmp. */
		if( pgpConnectRefEqual(cache->clientID, clientID) &&
			(IsNull(userName) ||
			 (IsntNull(cache->userName) &&
			  0 == strcmp((char *) userName, (char *) cache->userName))) &&
			PGPCompareKeyIDs( &cache->keyID, &keyID ) == 0 )
		{
			/* Found matching entry */
			break;
		}
		cache = cache->next;
	}

	if( IsNull( cache ) )
	{
		cache = (struct PGPPassCache *) pgpContextMemAlloc( context,
												sizeof (*cache), 0 );
		if( IsNull( cache ) )
		{
			pgpClearMemory( newPhrase, passLength );
			PGPFreeData( newPhrase );
			return kPGPError_OutOfMemory;
		}

		pgpClearMemory( cache, sizeof (*cache) );
		cache->globalCache = !!globalCache;
		cache->clientID = clientID;
		cache->keyID = keyID;
		if( globalCache )
		{
			if( IsntNull( userName ) )
			{
				/* Allocate before linking, so a failure leaves no entry
				 * whose userName is missing. */
				cache->userName = (PGPByte *) pgpContextMemAlloc( context,
													userNameLength+1, 0 );
				if( IsNull( cache->userName ) )
				{
					PGPFreeData( cache );
					pgpClearMemory( newPhrase, passLength );
					PGPFreeData( newPhrase );
					return kPGPError_OutOfMemory;
				}
				pgpCopyMemory( userName, cache->userName,
							   userNameLength+1 );	/* includes the NUL */
			}
			cache->next = cacheHeader->globalCache;
			cacheHeader->globalCache = cache;
		} else {
			cache->next = cacheHeader->localCaches;
			cacheHeader->localCaches = cache;
		}
	}

	/* Wipe the previous secret before releasing it, as sDestroyPassCache does. */
	if( IsntNull( cache->hashedPhrase ) )
	{
		pgpClearMemory( cache->hashedPhrase, cache->hashedPhraseLength );
		PGPFreeData( cache->hashedPhrase );
	}
	cache->hashedPhrase = newPhrase;
	cache->hashedPhraseLength = passLength;

	cache->timeoutInterval = cacheTimeOut;
	now = PGPGetTime();
	cache->timeoutTime = now + cache->timeoutInterval;
	if( cache->timeoutTime < now )
		cache->timeoutTime = FOREVER;
	return kPGPError_NoErr;
}


/*
 * Retrieve the cached passphrase, if any, for the given keyID.
 *
 * The global list is searched first.  It matches on keyID and, on Win32
 * when the connection carries a user name, on that user name.  The local
 * list is searched next and matches the calling connection and keyID.
 * Entries whose timeoutTime has passed are skipped; pgpExpirePassphraseCache
 * removes them later.  A hit moves the entry's expiration forward by its
 * timeoutInterval.
 *
 * On success *passPhrase points at the entry's own buffer (not a copy),
 * and *passLength is its length.  The pointer stays valid only while the
 * entry remains cached, so the caller must keep holding gPassCacheMutex
 * while using it.  Returns kPGPError_BadPassphrase, with *passPhrase NULL
 * and *passLength 0, when nothing usable is cached.
 */
	static PGPError
sRetrieveCachedPassphraseInternal(PGPContextRef context, PGPKeyID keyID,
	PGPByte const **passPhrase, PGPSize *passLength)
{
	struct PGPCacheHeader *	cacheHeader;
	struct PGPPassCache *	cache;
	PGPConnectRef			clientID;
	PGPTime					now;
	PGPByte *				userName = NULL;

	*passPhrase = NULL;
	*passLength = 0;

	/*
	 * No header means nothing has been cached for this context.  Report a
	 * miss.  Returning kPGPError_NoErr here would make the caller unlock
	 * with a NULL passphrase.
	 */
	cacheHeader = pgpContextGetPassphraseCache( context );
	if( IsNull( cacheHeader ) )
		return kPGPError_BadPassphrase;

	clientID = pgpContextGetConnectRef( context );
#if PGP_WIN32
	{
		pgpRPCconnection *connectRef = (pgpRPCconnection *) clientID;
		if( IsntNull( connectRef ) && IsntNull( connectRef->UserName ) )
		{
			userName = connectRef->UserName;
		}
	}
#endif

	/*
	 * Search for matching cache entries.  Skip those which have expired,
	 * they will be removed shortly by the expiration thread.
	 */
	now = PGPGetTime();
	cache = cacheHeader->globalCache;
	while( IsntNull(cache) )
	{
		/* cache->userName may be NULL, so guard it before strcmp. */
		if( PGPCompareKeyIDs( &cache->keyID, &keyID) == 0  &&
			(IsNull(userName) ||
			 (IsntNull(cache->userName) &&
			  0 == strcmp((char *) userName, (char *) cache->userName))) &&
			cache->timeoutTime >= now )
			break;
		cache = cache->next;
	}

	if( IsNull( cache ) )
	{
		cache = cacheHeader->localCaches;
		while( IsntNull(cache) )
		{
			if( pgpConnectRefEqual(cache->clientID, clientID) &&
				PGPCompareKeyIDs( &cache->keyID, &keyID ) == 0 &&
				cache->timeoutTime >= now )
				break;
			cache = cache->next;
		}
	}

	if( IsNull( cache ) )
		return kPGPError_BadPassphrase;

	*passLength = cache->hashedPhraseLength;
	*passPhrase = cache->hashedPhrase;
	cache->timeoutTime = now + cache->timeoutInterval;
	if( cache->timeoutTime < now )
		cache->timeoutTime = FOREVER;
	return kPGPError_NoErr;
}




/*
 * Unlock the given secret key, using the specified passphrase or one
 * from the cache if available.  Also cache the passphrase if requested.
 *
 * passphrase, passphraseLength: caller-supplied passphrase.  A length of 0
 *   means none was given; if the key is locked, the cache is consulted.
 * hashedPhrase: TRUE if passphrase is already in converted (hashed) form.
 * cacheTimeOut: seconds to keep the passphrase cached; 0 disables caching.
 *   A passphrase that came from the cache is not re-cached.
 * cacheGlobal: cache on the global list instead of the per-connection one.
 *
 * Returns kPGPError_NoErr if the key was unlocked.  Returns
 * kPGPError_BadPassphrase if the passphrase (supplied or cached) did not
 * unlock it, or if nothing was cached.  Otherwise returns the error from
 * pgpSecKeyUnlock or from caching.
 * If the key is already unlocked, returns kPGPError_NoErr when no
 * passphrase was given and kPGPError_BadPassphrase otherwise.
 * Failing to convert the passphrase for caching is ignored, and the key
 * stays unlocked.
 */
	PGPError
pgpSecKeyUnlockWithCache( PGPSecKey *sec, PGPByte const *passphrase,
	PGPSize passphraseLength, PGPBoolean hashedPhrase, PGPUInt32 cacheTimeOut,
	PGPBoolean cacheGlobal )
{
	PGPKeyID keyID;
	PGPContextRef context;
	PGPEnv *pgpEnv;
	PGPByte *passBuffer = NULL;
	PGPSize bufsize = 0;
	PGPError err = kPGPError_NoErr;
	PGPBoolean usedCache = FALSE;

	PGPMutexLock(&gPassCacheMutex);

	context = sec->context;
	pgpEnv = pgpContextGetEnvironment( context );
	pgpNewKeyIDFromRawData( sec->keyID, sec->pkAlg, 8, &keyID );

	if( passphraseLength == 0 )
		passphrase = NULL;

	if( !pgpSecKeyIslocked( sec ) )
	{
		/* Already unlocked: succeed only if no passphrase was supplied. */
		err = IsntNull( passphrase ) ? kPGPError_BadPassphrase
									 : kPGPError_NoErr;
		goto done;
	}

	if( IsNull( passphrase ) )
	{
		/* The returned pointer refers into the cache entry and is valid
		 * while gPassCacheMutex is held. */
		err = sRetrieveCachedPassphraseInternal( context, keyID,
										&passphrase, &passphraseLength );
		if( IsPGPError( err ) )
			goto done;
		hashedPhrase = TRUE;
		usedCache = TRUE;
	}

	/* pgpSecKeyUnlock yields 1 on success and 0 for a wrong passphrase;
	 * any other value is passed back to the caller as an error. */
	err = (PGPError)pgpSecKeyUnlock( sec, (const char *)passphrase,
									 passphraseLength, hashedPhrase );
	if( err != 1 )
	{
		if( err == 0 )
			err = kPGPError_BadPassphrase;
		goto done;
	}
	err = kPGPError_NoErr;

	/* The key is unlocked.  Cache a caller-supplied passphrase if requested. */
	if( usedCache || cacheTimeOut == 0 )
		goto done;

	if( !hashedPhrase )
	{
		/* Caching is best effort.  On any failure in this branch err is
		 * still kPGPError_NoErr, so the unlock is reported as successful. */
		if( IsPGPError( pgpSecKeyLockingalgorithm( sec, NULL, &bufsize ) ) )
			goto done;
		passBuffer = (PGPByte *)PGPNewSecureData(
						PGPPeekContextMemoryMgr( context ), bufsize, 0 );
		if( IsNull( passBuffer ) )
			goto done;
		if( IsPGPError( pgpSecKeyConvertPassphrase( sec, pgpEnv,
								(const char *)passphrase, passphraseLength,
								passBuffer ) ) )
			goto done;
		passphrase = passBuffer;
		passphraseLength = bufsize;
		hashedPhrase = TRUE;
	}
	pgpAssert( hashedPhrase );
	err = sCachePassphraseInternal( context, keyID, passphrase,
					passphraseLength, cacheTimeOut, cacheGlobal );

done:
	if( IsntNull( passBuffer ) )
	{
		/* Scrub the hashed passphrase before releasing it. */
		pgpClearMemory( passBuffer, bufsize );
		PGPFreeData( passBuffer );
	}
	PGPMutexUnlock(&gPassCacheMutex);
	return err;
}


/*
 * Put the passphrase for the given key into the cache.
 *
 * A NULL or zero-length passphrase is ignored.  If hashedPhrase is FALSE,
 * the passphrase is first converted with the key's locking algorithm.
 * Conversion failures are ignored (best effort) and reported as
 * kPGPError_NoErr.  Otherwise the result of adding the cache entry is
 * returned.
 * NOTE(review): the passphrase is not checked against the key here --
 * confirm that callers verify it first.
 */
	PGPError
pgpSecKeyCachePassphrase( PGPSecKey *sec, PGPByte const *passphrase,
	PGPSize passphraseLength, PGPBoolean hashedPhrase, PGPUInt32 cacheTimeOut,
	PGPBoolean cacheGlobal )
{
	PGPContextRef context;
	PGPEnv *pgpEnv;
	PGPSize bufsize = 0;
	PGPByte *passBuffer = NULL;
	PGPKeyID keyID;
	PGPError err = kPGPError_NoErr;

	context = sec->context;
	pgpEnv = pgpContextGetEnvironment( context );
	PGPMutexLock(&gPassCacheMutex);

	/* Every exit after this point goes through "done", so the mutex is
	 * always released. */
	if( passphraseLength == 0 || IsNull( passphrase ) )
		goto done;

	if( !hashedPhrase )
	{
		if( IsPGPError( pgpSecKeyLockingalgorithm( sec, NULL, &bufsize ) ) )
			goto done;
		passBuffer = (PGPByte *)PGPNewSecureData(
							PGPPeekContextMemoryMgr( context ), bufsize, 0 );
		if( IsNull( passBuffer ) )
			goto done;
		if( IsPGPError( pgpSecKeyConvertPassphrase( sec, pgpEnv,
										   (const char *)passphrase,
										   passphraseLength, passBuffer ) ) )
			goto done;
		passphrase = passBuffer;
		passphraseLength = bufsize;
		hashedPhrase = TRUE;
	}

	/* Add to cache */
	pgpNewKeyIDFromRawData( sec->keyID, sec->pkAlg, 8, &keyID );
	pgpAssert( hashedPhrase );
	err = sCachePassphraseInternal( context, keyID, passphrase,
							passphraseLength, cacheTimeOut, cacheGlobal );

done:
	if( IsntNull( passBuffer ) )
	{
		/* Scrub the hashed passphrase before releasing it. */
		pgpClearMemory( passBuffer, bufsize );
		PGPFreeData( passBuffer );
	}
	PGPMutexUnlock(&gPassCacheMutex);
	return err;
}


/*
 * Periodic sweep.  Removes every global and local cache entry whose
 * timeoutTime is earlier than the current time.  Does nothing if the
 * context has no cache.  Runs under gPassCacheMutex.
 */
	void
pgpExpirePassphraseCache( PGPContextRef context )
{
	struct PGPCacheHeader *	header;
	struct PGPPassCache **	heads[2];
	struct PGPPassCache **	link;
	PGPTime					cutoff;
	int						which;

	PGPMutexLock(&gPassCacheMutex);

	header = pgpContextGetPassphraseCache( context );
	if( IsntNull( header ) )
	{
		cutoff = PGPGetTime();
		heads[0] = &header->globalCache;
		heads[1] = &header->localCaches;

		for( which = 0; which < 2; which++ )
		{
			link = heads[which];
			while( IsntNull( *link ) )
			{
				if( (*link)->timeoutTime < cutoff )
					sExpirePassCache( link );	/* *link now holds the successor */
				else
					link = &(*link)->next;
			}
		}
	}

	PGPMutexUnlock(&gPassCacheMutex);
}


/*
 * Public entry point: clear all cached passphrase data for the context.
 * When RPC is enabled, the work is delegated to
 * pgpPurgePassphraseCache_back.  Otherwise the local cache is purged
 * directly.
 */
	PGPError
PGPPurgePassphraseCache( PGPContextRef context )
{
	PGPValidateContext( context );

	return pgpRPCEnabled() ? pgpPurgePassphraseCache_back( context )
						   : pgpPurgePassphraseCache_internal( context );
}


/*__Editor_settings____

	Local Variables:
	tab-width: 4
	End:
	vi: ts=4 sw=4
	vim: si
_____________________*/
